# This workflow will:
# - Create a new GitHub release
# - Build wheels for supported architectures
# - Deploy the wheels to the GitHub release
# - Release the source distribution (sdist) to PyPI
# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries

name: Build wheels and deploy

# Run only when a version tag (v*) is pushed.
# The `create` event fires for every newly created branch *and* tag, and GitHub documents
# the `tags` filter only for `push`. Filtering on tags under `create` therefore risks
# starting a release build whenever a branch is created.
on:
  push:
    tags:
      - 'v*'

jobs:

  # Creates the GitHub release for the tag that triggered this run. The build_wheels job
  # depends on it and later uploads its wheels to the release with the same tag.
  setup_release:
    name: Create Release
    runs-on: ubuntu-latest
    steps:
      # Strip the "refs/tags/" prefix from GITHUB_REF and expose the bare tag name as the
      # step output `branch`.
      - name: Get the tag version
        id: extract_branch
        # Outputs are written to $GITHUB_OUTPUT; the `::set-output` workflow command is deprecated.
        run: echo "branch=${GITHUB_REF#refs/tags/}" >> "$GITHUB_OUTPUT"
        shell: bash

      # The tag name is used both as the release's tag and as its display name.
      - name: Create Release
        id: create_release
        uses: actions/create-release@v1
        env:
          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
        with:
          tag_name: ${{ steps.extract_branch.outputs.branch }}
          release_name: ${{ steps.extract_branch.outputs.branch }}

  # Builds one wheel per (OS, Python, PyTorch, CUDA, C++11 ABI) combination and attaches it
  # to the release created by setup_release.
  build_wheels:
    name: Build Wheel
    needs: [setup_release]
    runs-on: ${{ matrix.os }}

    strategy:
      # Keep the remaining combinations running when one of them fails.
      fail-fast: false
      matrix:
        # ubuntu-20.04 is used instead of 22.04 for broader glibc compatibility. The manylinux
        # docker image would be ideal, but installing CUDA on manylinux is not figured out yet.
        os:
          - ubuntu-20.04
        python-version:
          - '3.7'
          - '3.8'
          - '3.9'
          - '3.10'
          - '3.11'
        torch-version:
          - '1.12.1'
          - '1.13.1'
          - '2.0.1'
          - '2.1.2'
          - '2.2.0'
          - '2.3.0.dev20240105'
        cuda-version:
          - '11.8.0'
          - '12.2.2'
        # Separate wheels are needed with and without the C++11 ABI (-D_GLIBCXX_USE_CXX11_ABI).
        # PyTorch wheels currently do not use it, but nvcr images ship PyTorch compiled with the
        # C++11 ABI. A wheel built without it fails to import on nvcr images with
        # "undefined symbol: _ZN3c105ErrorC2ENS_14SourceLocationESs".
        cxx11_abi:
          - 'FALSE'
          - 'TRUE'
        exclude:
          # See https://github.com/pytorch/pytorch/blob/main/RELEASE.md#release-compatibility-matrix
          # PyTorch <= 1.12 does not support Python 3.11.
          - torch-version: '1.12.1'
            python-version: '3.11'
          # PyTorch >= 2.0 only supports Python >= 3.8.
          - torch-version: '2.0.1'
            python-version: '3.7'
          - torch-version: '2.1.2'
            python-version: '3.7'
          - torch-version: '2.2.0'
            python-version: '3.7'
          - torch-version: '2.3.0.dev20240105'
            python-version: '3.7'
          # PyTorch <= 2.0 only supports CUDA <= 11.8.
          - torch-version: '1.12.1'
            cuda-version: '12.2.2'
          - torch-version: '1.13.1'
            cuda-version: '12.2.2'
          - torch-version: '2.0.1'
            cuda-version: '12.2.2'

    steps:
      - name: Checkout
        uses: actions/checkout@v3

      # Interpreter version comes from the current matrix entry.
      - name: Set up Python
        uses: actions/setup-python@v4
        with:
          python-version: ${{ matrix.python-version }}

      # Derive compact version tags from the matrix and export them to the following steps:
      #   MATRIX_CUDA_VERSION    first two components, no dot   (11.8.0 -> 118)
      #   MATRIX_TORCH_VERSION   first two components, dotted   (2.0.1  -> 2.0)
      #   MATRIX_PYTHON_VERSION  first two components, no dot   (3.10   -> 310)
      - name: Set CUDA and PyTorch versions
        run: |
          cuda_tag=$(echo ${{ matrix.cuda-version }} | awk -F. '{print $1 $2}')
          torch_tag=$(echo ${{ matrix.torch-version }} | awk -F. '{print $1 "." $2}')
          python_tag=$(echo ${{ matrix.python-version }} | awk -F. '{print $1 $2}')
          {
            echo "MATRIX_CUDA_VERSION=${cuda_tag}"
            echo "MATRIX_TORCH_VERSION=${torch_tag}"
            echo "MATRIX_PYTHON_VERSION=${python_tag}"
          } >> "$GITHUB_ENV"

      # Reclaim disk space on Linux runners by deleting the preinstalled .NET, GHC and CodeQL
      # directories. Approach taken from:
      #   https://github.com/easimon/maximize-build-space/blob/master/action.yml
      #   https://github.com/easimon/maximize-build-space/tree/test-report
      - name: Free up disk space
        if: runner.os == 'Linux'
        run: |
          sudo rm -rf \
            /usr/share/dotnet \
            /opt/ghc \
            /opt/hostedtoolcache/CodeQL

      # Configure 10 GB of swap on Linux runners.
      - name: Set up swap space
        if: runner.os == 'Linux'
        uses: pierotofy/set-swap-space@v1.0
        with:
          swap-size-gb: 10

      # Install the CUDA toolkit for the current matrix entry; skipped when cuda-version is 'cpu'.
      - name: Install CUDA ${{ matrix.cuda-version }}
        id: cuda-toolkit
        if: matrix.cuda-version != 'cpu'
        uses: Jimver/cuda-toolkit@v0.2.14
        with:
          cuda: ${{ matrix.cuda-version }}
          # The action's default method is "local", but that hit caching errors for CUDA 11.8
          # and 12.1, so the network installer is used for every version. Earlier variant:
          # method: ${{ (matrix.cuda-version == '11.8.0' || matrix.cuda-version == '12.1.0') && 'network' || 'local' }}
          method: network
          linux-local-args: '["--toolkit"]'
          # Installation is deliberately not restricted with `sub-packages: '["nvcc"]'`:
          # compiling PyTorch extensions needs the CUDA libraries (e.g. cuSparse, cuSolver),
          # not just nvcc.

      # Installs the PyTorch build matching the matrix entry, then prints the toolchain versions
      # (nvcc, Python, PyTorch, PyTorch's CUDA, CUDA_HOME) for debugging.
      - name: Install PyTorch ${{ matrix.torch-version }}+cu${{ matrix.cuda-version }}
        shell: bash
        run: |
          pip install --upgrade pip
          # If we don't install lit before installing PyTorch, we get an error for torch 2.0.1:
          # ERROR: Could not find a version that satisfies the requirement setuptools>=40.8.0 (from versions: none)
          pip install lit
          # Figure out which CUDA build of PyTorch to download: the system CUDA version is clamped
          # into the [minv, maxv] range of CUDA builds listed for this PyTorch release.
          # e.g. the system CUDA version can be 11.7, but torch==1.12 needs the wheel from cu116.
          # see https://github.com/pytorch/pytorch/blob/main/RELEASE.md#release-compatibility-matrix
          # This code is ugly, maybe there's a better way to do this.
          #
          # Assign first and export afterwards: `export VAR=$(cmd)` returns the exit status of
          # `export`, which would hide a failure of the Python snippet (e.g. a PyTorch version
          # missing from the tables below) and carry on with an empty TORCH_CUDA_VERSION.
          TORCH_CUDA_VERSION=$(python -c "from os import environ as env; \
            minv = {'1.12': 113, '1.13': 116, '2.0': 117, '2.1': 118, '2.2': 118, '2.3': 118}[env['MATRIX_TORCH_VERSION']]; \
            maxv = {'1.12': 116, '1.13': 117, '2.0': 118, '2.1': 121, '2.2': 121, '2.3': 121}[env['MATRIX_TORCH_VERSION']]; \
            print(max(min(int(env['MATRIX_CUDA_VERSION']), maxv), minv))")
          export TORCH_CUDA_VERSION
          if [[ "${{ matrix.torch-version }}" == *"dev"* ]]; then
            # Nightly ("dev") builds come from the nightly index.
            if [[ ${MATRIX_TORCH_VERSION} == "2.2" ]]; then
              # --no-deps because we can't install old versions of pytorch-triton
              pip install typing-extensions jinja2
              pip install --no-cache-dir --no-deps --pre https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION}/torch-${{ matrix.torch-version }}%2Bcu${TORCH_CUDA_VERSION}-cp${MATRIX_PYTHON_VERSION}-cp${MATRIX_PYTHON_VERSION}-linux_x86_64.whl
            else
              pip install --no-cache-dir --pre torch==${{ matrix.torch-version }} --index-url https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION}
            fi
          else
            pip install --no-cache-dir torch==${{ matrix.torch-version }} --index-url https://download.pytorch.org/whl/cu${TORCH_CUDA_VERSION}
          fi
          nvcc --version
          python --version
          python -c "import torch; print('PyTorch:', torch.__version__)"
          python -c "import torch; print('CUDA:', torch.version.cuda)"
          python -c "from torch.utils import cpp_extension; print (cpp_extension.CUDA_HOME)"

      # Compiles the wheel, renames it to include the CUDA / PyTorch / ABI combination, and
      # publishes the final filename to later steps as the `wheel_name` environment variable.
      - name: Build wheel
        run: |
          # setuptools >= 49.6.0 is needed, otherwise the extension cannot be compiled when the
          # system CUDA version is 11.7 and the PyTorch CUDA version is 11.6:
          # https://github.com/pytorch/pytorch/blob/664058fa83f1d8eede5d66418abff6e20bd76ca8/torch/utils/cpp_extension.py#L810
          # That still fails, hence the newer pinned setuptools below.
          pip install setuptools==68.0.0
          pip install ninja packaging wheel
          export PATH=/usr/local/nvidia/bin:/usr/local/nvidia/lib64:$PATH
          export LD_LIBRARY_PATH=/usr/local/nvidia/lib64:/usr/local/cuda/lib64:$LD_LIBRARY_PATH
          # MAX_JOBS is capped, otherwise the GitHub runner goes OOM.
          MAX_JOBS=2 \
          FLASH_ATTENTION_FORCE_BUILD="TRUE" \
          FLASH_ATTENTION_FORCE_CXX11_ABI=${{ matrix.cxx11_abi}} \
            python setup.py bdist_wheel --dist-dir=dist
          # Replace the second "-" of the wheel filename with "+<build_tag>-" so the name
          # records the CUDA / PyTorch / C++11 ABI combination it was built for.
          build_tag=cu${MATRIX_CUDA_VERSION}torch${MATRIX_TORCH_VERSION}cxx11abi${{ matrix.cxx11_abi }}
          wheel_name=$(ls dist/*whl | xargs -n 1 basename | sed "s/-/+${build_tag}-/2")
          ls dist/*whl | xargs -I {} mv {} dist/${wheel_name}
          echo "wheel_name=${wheel_name}" >> $GITHUB_ENV

      # Show what ended up in dist/ to make upload problems easier to diagnose.
      - name: Log Built Wheels
        run: ls dist

      # Strip the "refs/tags/" prefix from GITHUB_REF and expose the bare tag name as the
      # step output `branch`.
      - name: Get the tag version
        id: extract_branch
        # Outputs are written to $GITHUB_OUTPUT; the `::set-output` workflow command is deprecated.
        run: echo "branch=${GITHUB_REF#refs/tags/}" >> "$GITHUB_OUTPUT"

      # Look up the release for the current tag; its `upload_url` output is consumed below.
      - name: Get Release with tag
        id: get_current_release
        uses: joutvhu/get-release@v1
        env:
          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
        with:
          tag_name: ${{ steps.extract_branch.outputs.branch }}

      # Attach the renamed wheel (env.wheel_name, set by the "Build wheel" step) to that release.
      - name: Upload Release Asset
        id: upload_release_asset
        uses: actions/upload-release-asset@v1
        env:
          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
        with:
          upload_url: ${{ steps.get_current_release.outputs.upload_url }}
          asset_name: ${{ env.wheel_name }}
          asset_path: ./dist/${{ env.wheel_name }}
          asset_content_type: 'application/*'

  # Builds the source distribution with the CUDA build skipped and uploads it to PyPI.
  # Runs only after the build_wheels job has completed.
  publish_package:
    name: Publish package
    runs-on: ubuntu-latest
    needs:
      - build_wheels

    steps:
      - uses: actions/checkout@v3

      - uses: actions/setup-python@v4
        with:
          python-version: '3.10'

      - name: Install dependencies
        run: |
          pip install ninja packaging setuptools wheel twine
          # The CPU-only index is used so nothing CUDA-related gets downloaded here.
          pip install torch --index-url https://download.pytorch.org/whl/cpu

      # FLASH_ATTENTION_SKIP_CUDA_BUILD is set to "TRUE" for the sdist build.
      - name: Build core package
        env:
          FLASH_ATTENTION_SKIP_CUDA_BUILD: 'TRUE'
        run: python setup.py sdist --dist-dir=dist

      # twine authenticates with the PyPI API token stored in the repository secrets.
      - name: Deploy
        env:
          TWINE_USERNAME: '__token__'
          TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }}
        run: python -m twine upload dist/*
